Added support for JAX for Neumann and Robin BC #2015
base: master
Conversation
def error(self, X, inputs, outputs, beg, end, aux_var=None):
    if self.batch_size is not None:
        return self.func(inputs, outputs, X)[beg:end] - self.values[self.batch_indices]
    return (
raise NotImplementedError(
    "Reverse-mode autodiff doesn't support 3D output"
)
raise NotImplementedError("Reverse-mode autodiff doesn't support 3D output")
dydx = grad.jacobian(outputs, inputs, i=self.component, j=None)[beg:end]
if backend_name == "jax":
    dydx = grad.jacobian(
        (outputs, self.func), inputs, i=self.component, j=None
Why self.func? Here we should compute the derivative of the network outputs w.r.t. the network inputs; self.func is the function in the BC.
Inside the gradients module, the Jacobian calculation for JAX needs an input with 2 elements: the first element is the same as in the other backends (what we want to compute the gradient of), but the second element needs to be a function that is used to compute that gradient (the BC function in this case). Previously, only the output was fed to grad.jacobian, and that was producing an error. It should have been an out-of-index error, at line 50 of deepxde/gradients/gradients_reverse.py and line 76 of deepxde/gradients/gradients_forward.py, but JAX does not throw an error in those instances.
Therefore, the error raised at those lines was a different one, which made it difficult to pinpoint, and my conclusion was that the function to compute the gradient from needed to be passed forward to the Jacobian calculation.
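To make the call convention concrete, here is a minimal sketch of the pattern described above, outside of any BC class (the function `f` and the data are illustrative, and it assumes the JAX backend has already been selected):

```python
import jax.numpy as jnp
import deepxde as dde  # assumes DDE_BACKEND=jax is active

def f(x):
    # Illustrative callable producing the quantity to differentiate
    # (in this PR it is self.func; the discussion below is about whether
    # it should instead be the network forward pass).
    return jnp.sin(x)

x = jnp.linspace(0.0, 1.0, 5).reshape(-1, 1)
y = f(x)

# Other backends: dde.grad.jacobian(y, x, i=0, j=0)
# JAX backend: pass the (values, callable) pair as the first argument.
dydx = dde.grad.jacobian((y, f), x, i=0, j=0)
# With JAX, the result may itself be a (values, function) pair.
```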
elif backend_name == "jax":
    tangent = jax.numpy.zeros(self.dim_x).at[j].set(1)
    grad_fn = lambda x: jax.jvp(self.ys[1], (x,), (tangent,))[1]
    self.J[j] = (jax.vmap(grad_fn)(self.xs), grad_fn)

This is what was causing the error. When computing self.ys[1], instead of out of bounds, it was throwing a different error (I believe it was something along the lines of TypeError: 'jaxlib._jax.ArrayImpl' object is not callable).
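The symptom is easy to reproduce with plain JAX (a standalone sketch, not code from this PR): jax.jvp requires its first argument to be callable, so if self.ys is a plain array rather than an (array, function) pair, self.ys[1] is just another array and the call fails with a TypeError instead of an index error.

```python
import jax
import jax.numpy as jnp

f = jnp.sin                        # a callable, as jax.jvp expects
x = jnp.array([0.5])
tangent = jnp.ones_like(x)

# Works: the first argument is a function.
_, dydx = jax.jvp(f, (x,), (tangent,))

# Fails: if ys is a plain array, ys[1] is just another array, and
# jax.jvp raises a TypeError (roughly "... object is not callable").
ys = jnp.stack([f(x), f(x)])
try:
    jax.jvp(ys[1], (x,), (tangent,))
except TypeError as err:
    print(err)
```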
Yes, for JAX we need the additional element with the function. My question is why it is self.func. Shouldn't it be the function of the network forward pass?
I don't think so. We want to compute the boundary conditions. That part of the code is called in NeumannBC to compute the loss term, and, if I am not mistaken, self.func is the function that defines the boundary condition.
In short, what I mean is that we do not want to compute the gradients w.r.t. the network; what we want is to compute a term for the loss, not the gradient of the loss that can be used for backpropagation.
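For context, a simplified paraphrase of what NeumannBC.error computes (not the exact DeepXDE source; the call to self.func is simplified here): the residual dy/dn - f(x) on the boundary points, which then enters the loss.

```python
# Paraphrased sketch of NeumannBC.error, not the exact library code.
def error(self, X, inputs, outputs, beg, end, aux_var=None):
    values = self.func(X, beg, end, aux_var)  # f(x) on the boundary slice
    # BC residual dy/dn - f(x); this goes directly into the loss.
    return self.normal_derivative(X, inputs, outputs, beg, end) - values
```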
Neumann BC is dy/dn = f(x), where dy/dn is dy/dx in the direction of the normal n.
https://en.wikipedia.org/wiki/Neumann_boundary_condition
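To illustrate that definition with a standalone sketch (not code from this PR): the normal derivative is the directional derivative of y along the unit outward normal n, which JAX can compute either as n · grad(y) or as a JVP with n as the tangent.

```python
import jax
import jax.numpy as jnp

def y(x):
    # Example scalar field y(x1, x2); purely illustrative.
    return x[0] ** 2 + 3.0 * x[1]

x_b = jnp.array([1.0, 0.0])  # a boundary point
n = jnp.array([1.0, 0.0])    # unit outward normal at that point

# dy/dn = n . grad(y): directional derivative in the normal direction.
dy_dn = jnp.dot(jax.grad(y)(x_b), n)
# Equivalent via a JVP with the normal as the tangent direction.
dy_dn_jvp = jax.jvp(y, (x_b,), (n,))[1]
```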
I made a few modifications in gradients_reverse.py and boundary_conditions.py to enable out-of-the-box support for JAX, at least for the example deepxde/examples/pinn_forward/Poisson_Neumann_1d.py. I have tested the changes on a few examples, and they seem to work without issues for both PyTorch and JAX, but the testing has not been exhaustive.
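For anyone reproducing this, a minimal way to select the JAX backend before running that example (using the DDE_BACKEND environment variable; adjust if you pick the backend another way):

```python
import os
os.environ["DDE_BACKEND"] = "jax"  # must be set before importing deepxde

import deepxde as dde
print(dde.backend.backend_name)    # expected: "jax"
# Then run examples/pinn_forward/Poisson_Neumann_1d.py as usual.
```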